import os
import torch
import platform
import subprocess
from colorama import Fore, Style
from tempfile import NamedTemporaryFile
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
import torch, gc

# Runs when the module is first executed: trigger a garbage-collection pass and
# ask PyTorch to release its cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()


def init_model(model_path="/home/kean/models/Baichuan2-13B-Chat"):
    """Load the chat model, its generation config and its tokenizer.

    Args:
        model_path: Location handed to every ``from_pretrained`` call. The
            default is the checkpoint path this script was written against, so
            existing ``init_model()`` callers behave exactly as before.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is loaded in ``float16`` with
        ``device_map="auto"`` and has ``generation_config`` replaced by the one
        read from ``model_path``.
    """
    print("init model ...")
    # NOTE(review): trust_remote_code=True is passed for both the model and the
    # tokenizer; only point model_path at a checkpoint you trust.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    print("init model1 ...")
    model.generation_config = GenerationConfig.from_pretrained(model_path)
    print("init model2 ...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=False,
        trust_remote_code=True,
    )
    print("init model3 ...")
    return model, tokenizer


def clear_screen():
    """Wipe the terminal, print the usage banner and return a fresh history.

    Returns:
        A new empty list, which the caller uses as the reset message history.
    """
    # "cls" on Windows, "clear" everywhere else.
    wipe_command = "cls" if platform.system() == "Windows" else "clear"
    os.system(wipe_command)

    banner = "欢迎使用百川大模型，输入进行对话，vim 多行输入，clear 清空历史，CTRL+C 中断生成，stream 开关流式生成，exit 结束。"
    print(Fore.YELLOW + Style.BRIGHT + banner)
    return []


def vim_input():
    """Collect multi-line input by opening vim on a scratch file.

    vim is started in insert mode (``+star``) on a freshly created temporary
    file; once it exits, the file's content is read back and the file is
    removed.

    Returns:
        The text found in the scratch file after vim exits. This is an empty
        string when the file was left untouched (e.g. vim was quit without
        writing).

    Raises:
        OSError: Propagated from ``subprocess.call`` if ``vim`` cannot be
            launched, or from reading the scratch file. The scratch file is
            still cleaned up.
    """
    # delete=False keeps the file on disk after the handle is closed, so the
    # path stays reserved for vim and still exists even if nothing is saved.
    with NamedTemporaryFile(delete=False) as scratch:
        scratch_path = scratch.name
    try:
        subprocess.call(['vim', '+star', scratch_path])
        # Same default (locale) encoding as before; only the handle is now closed.
        with open(scratch_path) as saved:
            return saved.read()
    finally:
        try:
            os.unlink(scratch_path)
        except FileNotFoundError:
            # The editor removed the file itself; nothing left to clean up.
            pass


def main(stream=True):
    """Run the interactive terminal chat loop.

    Each turn reads one line from the user. After stripping whitespace the
    line is checked for commands: ``exit`` ends the session, ``clear`` wipes
    the screen and the history, ``vim`` collects the prompt through
    ``vim_input()``, and ``stream`` toggles streamed output. Anything else is
    appended to the history as a user message and sent to ``model.chat``; the
    reply is appended as the assistant message.

    Args:
        stream: Initial streaming setting. When true, replies are printed
            incrementally while they are generated; the user can flip it at
            runtime with the ``stream`` command.
    """
    model, tokenizer = init_model()
    messages = clear_screen()
    # Probe the MPS backend once instead of after every streamed chunk.
    use_mps = torch.backends.mps.is_available()
    while True:
        gc.collect()
        torch.cuda.empty_cache()
        try:
            prompt = input(Fore.GREEN + Style.BRIGHT + "\n用户：" + Style.NORMAL)
        except EOFError:
            # stdin was closed (e.g. Ctrl+D): leave the loop like "exit" does.
            break
        command = prompt.strip()
        if command == "exit":
            break
        if command == "clear":
            messages = clear_screen()
            continue
        if command == 'vim':
            prompt = vim_input()
            print(prompt)
        print(Fore.CYAN + Style.BRIGHT + "\nBaichuan 2：" + Style.NORMAL, end='')
        # Re-check the (possibly vim-supplied) prompt, as the original did.
        if prompt.strip() == "stream":
            stream = not stream
            print(Fore.YELLOW + "({}流式生成)\n".format("开启" if stream else "关闭"), end='')
            continue
        messages.append({"role": "user", "content": prompt})
        if stream:
            # Start from an empty reply so that an interrupt arriving before
            # the first chunk neither leaves `response` unbound nor records the
            # previous turn's reply for this turn.
            response = ""
            position = 0
            try:
                # Each yielded value is treated as the reply text so far; only
                # the part after `position` is new and gets printed.
                for response in model.chat(tokenizer, messages, stream=True):
                    print(response[position:], end='', flush=True)
                    position = len(response)
                    if use_mps:
                        torch.mps.empty_cache()
            except KeyboardInterrupt:
                # Ctrl+C deliberately stops generation; the partial reply
                # received so far is kept as the assistant message.
                pass
            print()
        else:
            response = model.chat(tokenizer, messages)
            print(response)
            if use_mps:
                torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})
    print(Style.RESET_ALL)


# Start the chat loop only when this file is executed as a script.
if __name__ == "__main__":
    main()
